import re
from datetime import datetime
from typing import Any, Dict, List

from fastapi import WebSocket
from langchain.tools import BaseTool
from paperqa import Answer, Docs

from ..docs import (
    reduce_tokens,
    scrape,
    stream_answer,
    stream_cost,
    stream_evidence,
    stream_filter,
)
from ..models import QueryRequest, ScrapeRequest


def get_year() -> str:
    """Return the current local year formatted with ``%Y`` (e.g. ``"2024"``)."""
    return f"{datetime.now():%Y}"


async def status(
    docs: Docs,
    answer: Answer,
    token_counts: Dict[str, List[int]],
    websocket: WebSocket,
):
    """Refresh the answer's token/cost totals and build a one-line status report.

    ``answer.token_counts`` and ``answer.cost`` are overwritten with the result
    of ``reduce_tokens``. The report lists the number of papers in ``docs``,
    the size of ``answer.dockey_filter`` (0 when it is empty or unset), the
    number of contexts on the answer, and the cost with two decimals.

    If ``websocket`` is not ``None`` the report is also sent over it as an
    ``"agent-status"`` JSON message.

    Returns:
        The report string prefixed with ``" |"``, ready to be appended to a
        tool's textual result.
    """
    totals, cost = reduce_tokens(token_counts, docs.llm, docs.summary_llm)
    answer.token_counts, answer.cost = totals, cost

    paper_count = len(docs.docs)
    # dockey_filter may be falsy (e.g. not yet populated); report 0 then.
    relevant_count = len(answer.dockey_filter) if answer.dockey_filter else 0
    evidence_count = len(answer.contexts)
    report = (
        f" Status: Paper Count={paper_count} | "
        f"Relevant Papers={relevant_count} | "
        f"Current Evidence={evidence_count} | "
        f"Current Cost=${answer.cost:.2f}"
    )

    # The websocket is optional at runtime even though the annotation is not.
    if websocket is not None:
        payload = {"c": "agent-status", "t": report}
        await websocket.send_json(payload)
    return " |" + report


class GatherEvidenceTool(BaseTool):
    """Agent tool that gathers evidence for a specific question.

    Only the asynchronous entry point is implemented; the synchronous
    ``_run`` always raises ``NotImplementedError``.
    """

    name = "gather_evidence"
    description = (
        "Give a specific question to get evidence for it. "
        "This will increase evidence and relevant paper counts. "
    )
    # Document collection handed to stream_filter / stream_evidence / status.
    docs: Docs
    # Answer object being built up; replaced by the results of the stream_* calls.
    answer: Answer
    # Connection forwarded to the stream_* helpers and to status().
    websocket: WebSocket
    # Original request; forwarded unchanged to stream_evidence.
    query: QueryRequest
    # Token accounting passed through to the helpers and status().
    token_counts: Dict[str, List[int]]

    def _run(self, query: str) -> str:
        """Synchronous execution is not supported."""
        raise NotImplementedError()

    async def _arun(self, query: str) -> str:
        """Gather evidence for ``query`` and report what was added.

        Args:
            query: The specific question supplied by the agent.

        Returns:
            A message with the number of newly added contexts, the
            highest-scoring context (if any), and the current status line.
        """
        # we update relevant papers here
        self.answer = await stream_filter(
            self.docs, query, self.answer, self.token_counts, self.websocket
        )

        # Temporarily swap in the agent's question while evidence is gathered.
        original_question = self.answer.question
        evidence_before = len(self.answer.contexts)
        self.answer.question = query
        try:
            self.answer = await stream_evidence(
                self.docs,
                self.query,
                self.token_counts,
                websocket=self.websocket,
                answer=self.answer,
            )
        finally:
            # Always restore the original question, even if gathering failed,
            # so a raised error cannot leave the answer with the wrong question.
            self.answer.question = original_question
        added = len(self.answer.contexts) - evidence_before

        best_evidence = ""
        if self.answer.contexts:
            # max() returns the first highest-scoring context, which is the
            # same element a stable descending sort would put first.
            top_context = max(self.answer.contexts, key=lambda ctx: ctx.score)
            best_evidence = f"Best evidence:\n\n{top_context.context}\n\n"
        return f"Added {added} pieces of evidence. {best_evidence}" + await status(
            self.docs, self.answer, self.token_counts, self.websocket
        )


class GenerateAnswerTool(BaseTool):
    """Agent tool that asks the model to propose an answer from current evidence.

    Only the asynchronous entry point is implemented; the synchronous
    ``_run`` always raises ``NotImplementedError``.
    """

    name = "gen_answer"
    description = (
        "Ask a model to propose an answer using evidence from papers. "
        "The input is the question to be answered. "
        "The tool may fail, indicating that better or different evidence should be found."
    )
    # Document collection handed to stream_answer / stream_cost / status.
    docs: Docs
    # Answer object being built up; replaced by the result of stream_answer.
    answer: Answer
    # Connection forwarded to the stream_* helpers and to status().
    websocket: WebSocket
    # Original request; also supplies tool_prompts.wipe_context_on_answer_failure.
    query: QueryRequest
    # Token accounting passed through to the helpers and status().
    token_counts: Dict[str, List[int]]

    def _run(self, query: str) -> str:
        """Synchronous execution is not supported."""
        raise NotImplementedError()

    async def _arun(self, query: str) -> str:
        """Generate an answer and return it (or a failure notice) plus status.

        The ``query`` argument is currently not used to change the question;
        the answer is generated for the request stored on the tool.

        Returns:
            The generated answer followed by the status line, or
            ``"Failed to answer question."`` followed by the status line when
            the generated text contains "cannot answer" (case-insensitive).
        """
        # TODO: Should we allow the agent to change the question
        # (i.e. set self.answer.question = query)?
        self.answer = await stream_answer(
            self.docs, self.query, self.answer, self.token_counts, self.websocket
        )

        answered = "cannot answer" not in self.answer.answer.lower()
        if answered:
            await stream_cost(
                self.docs, self.answer, self.token_counts, self.websocket
            )
            proposed = self.answer.answer
            return proposed + await status(
                self.docs, self.answer, self.token_counts, self.websocket
            )

        # Failure path: optionally drop the evidence gathered so far.
        if self.query.tool_prompts.wipe_context_on_answer_failure:
            self.answer.contexts = []
            self.answer.context = ""
        failure_notice = "Failed to answer question."
        return failure_notice + await status(
            self.docs,
            self.answer,
            self.token_counts,
            self.websocket,
        )


class UpstreamCitations(BaseTool):
    """Agent tool that scrapes documents citing a paper found in current evidence.

    Only the asynchronous entry point is implemented; the synchronous
    ``_run`` always raises ``NotImplementedError``.
    """

    name = "get_documents_that_cite"
    description = (
        "Ask a model to get the most relevant documents that cite the given paper. "
        "The input is the paper key."
    )
    # Document collection handed to scrape() and status().
    docs: Docs
    # Answer whose contexts are searched for the requested paper.
    answer: Answer
    # Connection forwarded to scrape() and status().
    websocket: WebSocket
    # Original request (not read by this tool's _arun).
    query: QueryRequest
    # Token accounting passed through to status().
    token_counts: Dict[str, List[int]]

    def _run(self, query: str) -> str:
        """Synchronous execution is not supported."""
        raise NotImplementedError()

    async def _arun(self, query: str) -> str:
        """Scrape citing documents for the paper identified by ``query``.

        ``query`` must be a substring of the ``dockey`` or ``docname`` of a
        document in the answer's current contexts; otherwise a "could not
        find" message is returned and nothing is scraped.

        Returns:
            The status line after scraping, or the not-found message.
        """
        # First context whose document key or name contains the query.
        matched_key = next(
            (
                ctx.text.doc.dockey
                for ctx in self.answer.contexts
                if query in ctx.text.doc.dockey or query in ctx.text.doc.docname
            ),
            None,
        )
        if matched_key is None:
            return f"Could not find paper {query} in current evidence."

        # NOTE(review): the resolved dockey is only used as an existence check;
        # the raw query is what gets sent to scrape(). Confirm this is intended.
        citation_request = ScrapeRequest(
            question=self.answer.question,
            query=query,
            search_type="future_citations",
        )
        await scrape(citation_request, self.docs, self.websocket)

        return await status(self.docs, self.answer, self.token_counts, self.websocket)


class DownstreamReferences(BaseTool):
    """Agent tool that scrapes the papers referenced by a paper in current evidence.

    Only the asynchronous entry point is implemented; the synchronous
    ``_run`` always raises ``NotImplementedError``.
    """

    name = "get_paper_bibliography"
    # Fixed duplicated word ("the the") in the agent-facing description.
    description = (
        "Retrieve the papers referenced in a paper's bibliography. "
        "The input is the paper key."
    )
    # Document collection handed to scrape() and status().
    docs: Docs
    # Answer whose contexts are searched for the requested paper.
    answer: Answer
    # Connection forwarded to scrape() and status().
    websocket: WebSocket
    # Original request (not read by this tool's _arun).
    query: QueryRequest
    # Token accounting passed through to status().
    token_counts: Dict[str, List[int]]

    def _run(self, query: str) -> str:
        """Synchronous execution is not supported."""
        raise NotImplementedError()

    async def _arun(self, query: str) -> str:
        """Scrape the bibliography of the paper identified by ``query``.

        ``query`` must be a substring of the ``dockey`` or ``docname`` of a
        document in the answer's current contexts; otherwise a "could not
        find" message is returned and nothing is scraped.

        Returns:
            The status line after scraping, or the not-found message.
        """
        # First context whose document key or name contains the query.
        key = next(
            (
                context.text.doc.dockey
                for context in self.answer.contexts
                if query in context.text.doc.dockey
                or query in context.text.doc.docname
            ),
            None,
        )
        if key is None:
            return f"Could not find paper {query} in current evidence."

        # NOTE(review): the resolved dockey is only used as an existence check;
        # the raw query is what gets sent to scrape(). Confirm this is intended.
        request = ScrapeRequest(
            question=self.answer.question, query=query, search_type="past_references"
        )
        await scrape(request, self.docs, self.websocket)

        return await status(self.docs, self.answer, self.token_counts, self.websocket)


async def _a_scape_search(
    query: str,
    docs: Docs,
    answer: Answer,
    websocket: WebSocket,
    token_counts: Dict[str, List[int]],
    return_paper_metadata: bool = False,
) -> Any:
    """Run a paper scrape for ``query``, optionally restricted to a year range.

    Double quotes are removed from the query. If the last space-separated
    word is exactly ``YYYY`` or ``YYYY-YYYY`` it is removed from the query and
    passed as the year range (a single year becomes ``YYYY-YYYY``).

    Args:
        query: Keyword search string, optionally ending in a year or range.
        docs: Document collection handed to ``scrape`` and ``status``.
        answer: Answer whose question accompanies the scrape request.
        websocket: Connection forwarded to ``scrape`` and ``status``.
        token_counts: Token accounting passed through to ``status``.
        return_paper_metadata: When true, also return what ``scrape`` returned.

    Returns:
        ``(papers, state_data)`` when ``return_paper_metadata`` is true,
        otherwise just ``state_data`` (the status line).
    """
    # remove quotes
    query = query.replace('"', "")
    # check if years are present
    last_word = query.split(" ")[-1]
    year = None
    # fullmatch: the whole word must be a year or range. A prefix match would
    # treat tokens such as "20201234" or "2020abc" as years.
    if re.fullmatch(r"\d{4}(-\d{4})?", last_word):
        # Drop the year and the separator (space/comma) that preceded it.
        query = query[: -len(last_word)].rstrip(", \t")
        year = last_word if "-" in last_word else f"{last_word}-{last_word}"

    request = ScrapeRequest(question=answer.question, query=query, year=year)
    papers = await scrape(request, docs, websocket)
    state_data = await status(docs, answer, token_counts, websocket)
    if return_paper_metadata:
        return papers, state_data
    return state_data


class PaperSearchTool(BaseTool):
    """Agent tool that searches for papers by keyword, optionally within a year range.

    Only the asynchronous entry point is implemented; the synchronous
    ``_run`` always raises ``NotImplementedError``.
    """

    name = "paper_search"
    # NOTE: get_year() is evaluated once, when this class body executes, so the
    # year in the description is fixed for the lifetime of the process.
    description = (
        "Search for papers to increase the paper count. Input should be a string of keywords. "
        "Use this format: [keyword search], [start year]-[end year]. "
        "You may include years as the last word in the query, e.g. 'machine learning 2020' "
        f"or 'machine learning 2010-2020'. The current year is {get_year()}."
    )
    # Document collection handed to scrape() and status().
    docs: Docs
    # Answer whose question accompanies the scrape request.
    answer: Answer
    # Connection forwarded to scrape() and status().
    websocket: WebSocket
    # Original request; supplies tool_prompts.search_count as the result limit.
    query: QueryRequest
    # Token accounting passed through to status().
    token_counts: Dict[str, List[int]]
    # When true, the tool output is prefixed with the retrieved titles/years.
    return_paper_metadata: bool = False
    # Forwarded as ScrapeRequest.search_type.
    search_type: str = "google"

    def _run(self, query: str) -> str:
        """Synchronous execution is not supported."""
        raise NotImplementedError()

    async def _arun(self, query: str) -> str:
        """Search for papers matching ``query`` and report the new status.

        Double quotes are removed from the query. If the last space-separated
        word is exactly ``YYYY`` or ``YYYY-YYYY`` it is removed from the query
        and passed as the year range (a single year becomes ``YYYY-YYYY``).

        Returns:
            The status line, preceded by a "Retrieved Papers" listing when
            ``return_paper_metadata`` is set.
        """
        # remove quotes
        query = query.replace('"', "")
        # check if years are present
        last_word = query.split(" ")[-1]
        year = None
        # fullmatch: the whole word must be a year or range. A prefix match
        # would treat tokens such as "20201234" or "2020abc" as years.
        if re.fullmatch(r"\d{4}(-\d{4})?", last_word):
            # Drop the year and the separator that preceded it; the documented
            # input format puts a comma between the keywords and the years.
            query = query[: -len(last_word)].rstrip(", \t")
            year = last_word if "-" in last_word else f"{last_word}-{last_word}"
        request = ScrapeRequest(
            question=self.answer.question,
            query=query,
            year=year,
            limit=self.query.tool_prompts.search_count,
            search_type=self.search_type,
        )
        papers = await scrape(request, self.docs, self.websocket)
        state_data = await status(
            self.docs, self.answer, self.token_counts, self.websocket
        )
        if not self.return_paper_metadata:
            return state_data

        # Only build the listing when it is actually returned.
        paper_lines = [
            f'{papers[key]["title"]} ({papers[key]["year"]})' for key in papers
        ]
        paper_info_string = "Retrieved Papers:\n" + "\n".join(paper_lines)
        return paper_info_string + "\n\n" + state_data


class SimilarPapersTool(BaseTool):
    """Agent tool that scrapes papers similar to a paper found in current evidence.

    Only the asynchronous entry point is implemented; the synchronous
    ``_run`` always raises ``NotImplementedError``.
    """

    name = "similar_papers"
    description = (
        "Search for papers similar to the given paper. Input should be a paper key."
    )
    # Document collection handed to scrape() and status().
    docs: Docs
    # Answer whose contexts are searched for the requested paper.
    answer: Answer
    # Connection forwarded to scrape() and status().
    websocket: WebSocket
    # Original request (not read by this tool's _arun).
    query: QueryRequest
    # Token accounting passed through to status().
    token_counts: Dict[str, List[int]]

    def _run(self, query: str) -> str:
        """Synchronous execution is not supported."""
        raise NotImplementedError()

    async def _arun(self, query: str) -> str:
        """Scrape papers similar to the paper identified by ``query``.

        ``query`` must be a substring of the ``dockey`` or ``docname`` of a
        document in the answer's current contexts; otherwise a "could not
        find" message is returned and nothing is scraped.

        Returns:
            The status line after scraping, or the not-found message.
        """
        # First context whose document key or name contains the query.
        matched_key = next(
            (
                ctx.text.doc.dockey
                for ctx in self.answer.contexts
                if query in ctx.text.doc.dockey or query in ctx.text.doc.docname
            ),
            None,
        )
        if matched_key is None:
            return f"Could not find paper {query} in current evidence."

        # NOTE(review): the resolved dockey is only used as an existence check;
        # the raw query is what gets sent to scrape(). Confirm this is intended.
        similar_request = ScrapeRequest(
            question=self.answer.question,
            query=query,
            search_type="paper",
        )
        await scrape(similar_request, self.docs, self.websocket)

        return await status(self.docs, self.answer, self.token_counts, self.websocket)
